import scalevi.var_dists.var_dists_base as var_dists_base
import scalevi.distributions.scale_transforms as scale_transforms
import scalevi.distributions.distributions as dists
import jax.numpy as np
import abc

class BaselineVarDist(var_dists_base.VarDist):
    """Common machinery for baseline variational distributions.

    The latent vector handled by the base distribution is flat with
    ``D_par + N_chunk * D_kid`` entries.  ``sample`` splits it into the first
    ``D_par`` entries and an ``(N_chunk, D_kid)`` array built from the rest;
    ``log_prob`` performs the reverse flattening.

    Subclasses provide ``base_dist``, ``get_params`` and ``initial_params``.
    """

    # NOTE(review): the default ``scale_transform`` instance is created once,
    # when the ``def`` is evaluated, and is shared by every object that relies
    # on the default.  Harmless if the transform is stateless -- confirm.
    def __init__(
            self, N_chunk, D_par, D_kid,
            scale_transform=scale_transforms.ProximalScaleTransform(1.0)):
        """Record the block sizes and scale transform, then init the parent.

        Args:
            N_chunk: number of rows of the second sample block.
            D_par: length of the first sample block.
            D_kid: number of columns of the second sample block.
            scale_transform: object used by subclasses to map between stored
                and constrained scale parameters.
        """
        self.D_par = D_par
        self.D_kid = D_kid
        self.N_chunk = N_chunk
        self.scale_transform = scale_transform
        # The parent constructor receives the total flat dimension.
        super(BaselineVarDist, self).__init__(D_par + N_chunk*D_kid)

    @abc.abstractmethod
    def base_dist(self, *args, **kwargs):
        """Return the underlying distribution object.

        ``sample`` and ``log_prob`` invoke this as
        ``base_dist(**self.get_params(params))``, so implementations must
        accept the keys produced by their own ``get_params`` as keyword
        arguments.
        """

    def sample(self, rng_key, params, chunk):
        """Draw one sample and split it into its two blocks.

        Args:
            rng_key: random key handed to the base distribution's ``sample``.
            params: stored parameters, converted through ``get_params``.
            chunk: accepted but not used by this implementation.

        Returns:
            Tuple ``(first, second)`` where ``first`` is the leading ``D_par``
            entries and ``second`` is the remainder reshaped to
            ``(N_chunk, D_kid)``.
        """
        s = self.base_dist(**self.get_params(params)).sample(rng_key)
        return s[:self.D_par], s[self.D_par:].reshape(self.N_chunk, self.D_kid)

    def log_prob(self, z, params, chunk):
        """Evaluate the base distribution's ``log_prob`` at ``z``.

        Args:
            z: indexable pair; ``z[0]`` and the raveled ``z[1]`` are
                concatenated into the flat vector passed to the base
                distribution.
            params: stored parameters, converted through ``get_params``.
            chunk: accepted but not used by this implementation.
        """
        return self.base_dist(**self.get_params(params)).log_prob(
            np.concatenate([z[0], z[1].ravel()])
        )

class Gaussian(BaselineVarDist):
    """Baseline variational family backed by ``dists.MultivariateNormal``.

    Stored parameters are ``"μ"`` (location) and ``"L"``; ``get_params`` maps
    them to the ``loc`` / ``scale_tril`` keyword arguments of the base
    distribution.
    """

    def __init__(self, N_chunk, D_par, D_kid,
                 scale_transform=scale_transforms.ProximalScaleTransform(1.0)):
        """Forward the block sizes and scale transform to ``BaselineVarDist``."""
        super(Gaussian, self).__init__(N_chunk, D_par, D_kid,
                                       scale_transform=scale_transform)

    def base_dist(self, *args, **kwargs):
        """Build the underlying ``dists.MultivariateNormal``.

        The parent class calls this as ``base_dist(**get_params(params))``,
        i.e. with ``loc=...`` and ``scale_tril=...``; a single positional
        ``params`` argument would reject those keywords with ``TypeError``.
        Positional arguments are still forwarded unchanged so that direct
        calls written against the former one-argument signature keep working.
        """
        return dists.MultivariateNormal(*args, **kwargs)

    def initial_params(self):
        """Return the starting parameter dictionary.

        ``"μ"`` is a zero vector of length ``z_dim`` and ``"L"`` is
        ``scale_transform.inverse`` applied to the ``z_dim`` x ``z_dim``
        identity matrix.
        """
        return {
            "μ": np.zeros(self.z_dim),
            "L": self.scale_transform.inverse(np.eye(self.z_dim))
            }

    def get_params(self, params):
        """Convert stored parameters into base-distribution keyword arguments.

        ``params['L']`` is passed through ``scale_transform.forward`` to give
        ``scale_tril``; ``params['μ']`` is used as ``loc`` unchanged.
        """
        return {
            "loc": params['μ'],
            "scale_tril": self.scale_transform.forward(params['L'])
            }

class Diagonal(BaselineVarDist):
    """Baseline variational family backed by ``dists.DiagonalNormal``.

    Stored parameters are ``"μ"`` and ``"L"`` (one entry per latent
    dimension); ``get_params`` maps them to the ``mu`` / ``cov`` keyword
    arguments of the base distribution.
    """

    def __init__(self, N_chunk, D_par, D_kid,
                 scale_transform=scale_transforms.ProximalScaleTransform(1.0)):
        """Forward the block sizes and scale transform to ``BaselineVarDist``."""
        super(Diagonal, self).__init__(N_chunk, D_par, D_kid,
                                       scale_transform=scale_transform)

    def base_dist(self, *args, **kwargs):
        """Build the underlying ``dists.DiagonalNormal``.

        The parent class calls this as ``base_dist(**get_params(params))``,
        i.e. with ``mu=...`` and ``cov=...``; a single positional ``params``
        argument would reject those keywords with ``TypeError``.  Positional
        arguments are still forwarded unchanged so that direct calls written
        against the former one-argument signature keep working.
        """
        return dists.DiagonalNormal(*args, **kwargs)

    def initial_params(self):
        """Return the starting parameter dictionary.

        ``"μ"`` is a zero vector of length ``z_dim`` and ``"L"`` is
        ``scale_transform.inverse_diag_transform`` applied to a vector of ones.
        """
        return {"μ": np.zeros(self.z_dim),
                "L": self.scale_transform.inverse_diag_transform(np.ones(self.z_dim))
                }

    def get_params(self, params):
        """Convert stored parameters into base-distribution keyword arguments.

        ``params['L']`` goes through ``forward_diag_transform`` and is then
        squared to produce ``cov`` (presumably turning a per-dimension scale
        into a variance -- confirm against ``dists.DiagonalNormal``).
        """
        return {
            "mu": params['μ'],
            "cov": self.scale_transform.forward_diag_transform(params['L'])**2
            }

class LowRankGauss(BaselineVarDist):
    """Baseline variational family backed by ``dists.LowRankMultivariateNormal``.

    Stored parameters are ``"μ"``, a ``(z_dim, r)`` factor ``"Λ"`` and a
    length-``z_dim`` vector ``"λ"``; ``get_params`` maps them to the ``loc`` /
    ``cov_factor`` / ``cov_diag`` keyword arguments of the base distribution.
    """

    def __init__(
            self, N_chunk, D_par, D_kid,
            r,
            scale_transform=scale_transforms.ProximalScaleTransform(1.0)):
        """Store ``r`` and forward the remaining arguments to the parent.

        Args:
            N_chunk, D_par, D_kid: block sizes, see ``BaselineVarDist``.
            r: number of columns of the ``"Λ"`` factor.
            scale_transform: transform applied to the ``"λ"`` parameter.
        """
        # Set before the parent constructor runs, as in the original ordering.
        self.r = r
        super(LowRankGauss, self).__init__(N_chunk, D_par, D_kid, scale_transform)

    def base_dist(self, *args, **kwargs):
        """Build the underlying ``dists.LowRankMultivariateNormal``.

        The parent class calls this as ``base_dist(**get_params(params))``,
        i.e. with ``loc=...``, ``cov_factor=...`` and ``cov_diag=...``; a
        single positional ``params`` argument would reject those keywords with
        ``TypeError``.  Positional arguments are still forwarded unchanged so
        that direct calls written against the former one-argument signature
        keep working.
        """
        return dists.LowRankMultivariateNormal(*args, **kwargs)

    def initial_params(self):
        """Return the starting parameter dictionary.

        ``"μ"`` and ``"Λ"`` start at zero; ``"λ"`` is
        ``scale_transform.inverse_diag_transform`` applied to a vector of ones.
        """
        return {
            "μ" : np.zeros(self.z_dim),
            "Λ" : np.zeros((self.z_dim, self.r)),
            "λ" : self.scale_transform.inverse_diag_transform(np.ones(self.z_dim))
        }

    def get_params(self, params):
        """Convert stored parameters into base-distribution keyword arguments.

        Only ``params['λ']`` is transformed (via ``forward_diag_transform``);
        ``params['μ']`` and ``params['Λ']`` are passed through unchanged.
        """
        return {
            "loc": params['μ'],
            "cov_factor": params['Λ'],
            "cov_diag":  self.scale_transform.forward_diag_transform(params['λ'])
        }

class BaselineVarDistWithSampleEval(var_dists_base.VarDistWithSampleEval):
    """Shared logic for baseline families that can sample and score jointly.

    The base distribution works on a flat vector of
    ``D_par + N_chunk * D_kid`` entries.  Samples are returned as a pair: the
    leading ``D_par`` entries and the remainder reshaped to
    ``(N_chunk, D_kid)``.  Subclasses provide ``base_dist``, ``get_params``
    and ``initial_params``.
    """

    def __init__(self, N_chunk, D_par, D_kid,
                 scale_transform=scale_transforms.ProximalScaleTransform(1.0)):
        """Keep the block sizes and scale transform, then init the parent."""
        self.N_chunk = N_chunk
        self.D_par = D_par
        self.D_kid = D_kid
        self.scale_transform = scale_transform
        # The parent constructor is given the total flat dimension.
        flat_dim = D_par + N_chunk*D_kid
        super().__init__(flat_dim)

    @abc.abstractmethod
    def base_dist(self, **kwargs):
        """Return the distribution object built from ``get_params`` output."""

    def sample_and_log_prob(self, rng_key, params, chunk, **kwargs):
        """Sample once and return ``((first, second), log_prob)``.

        ``chunk`` and any extra keyword arguments are accepted but unused.
        """
        dist = self.base_dist(**self.get_params(params))
        flat, lp = dist.sample_and_log_prob(rng_key)
        first = flat[:self.D_par]
        second = flat[self.D_par:].reshape(self.N_chunk, self.D_kid)
        return (first, second), lp

    def sample(self, rng_key, params, chunk, **kwargs):
        """Sample once and return the ``(first, second)`` pair.

        ``chunk`` and any extra keyword arguments are accepted but unused.
        """
        dist = self.base_dist(**self.get_params(params))
        flat = dist.sample(rng_key)
        first = flat[:self.D_par]
        second = flat[self.D_par:].reshape(self.N_chunk, self.D_kid)
        return first, second

    def log_prob(self, z, params, chunk, **kwargs):
        """Score ``z`` after flattening ``z[0]`` and ``z[1]`` into one vector.

        ``chunk`` and any extra keyword arguments are accepted but unused.
        """
        dist = self.base_dist(**self.get_params(params))
        flat = np.concatenate((z[0], z[1].ravel()))
        return dist.log_prob(flat)
        
class GaussianWithSampleEval(BaselineVarDistWithSampleEval):
    """Family backed by ``dists.CustomMultivariateNormal``.

    The scale parameter ``"L"`` is stored in the vector form produced by
    ``dists.util.matrix_to_tril_vec`` and expanded again with
    ``dists.util.vec_to_tril_matrix`` before the scale transform is applied.
    """

    def __init__(self, N_chunk, D_par, D_kid,
                 scale_transform=scale_transforms.ProximalScaleTransform(1.0)):
        """Pass every argument straight to the parent constructor."""
        super().__init__(N_chunk, D_par, D_kid, scale_transform=scale_transform)

    def base_dist(self, **kwargs):
        """Build the base distribution from ``loc`` / ``scale_tril`` keywords."""
        return dists.CustomMultivariateNormal(**kwargs)

    def initial_params(self):
        """Return zero ``"μ"`` and the vectorised inverse-transformed identity."""
        dim = self.z_dim
        loc0 = np.zeros(dim)
        scale0 = self.scale_transform.inverse(np.eye(dim))
        return {
            "μ": loc0,
            "L": dists.util.matrix_to_tril_vec(scale0),
        }

    def get_params(self, params):
        """Map stored parameters to base-distribution keyword arguments."""
        loc = params['μ']
        tril = dists.util.vec_to_tril_matrix(params['L'])
        return {
            "loc": loc,
            "scale_tril": self.scale_transform.forward(tril),
        }

class DiagonalWithSampleEval(BaselineVarDistWithSampleEval):
    """Family backed by ``dists.CustomDiagonalNormal``.

    Stored parameters are ``"μ"`` and ``"σ"``, each of length ``z_dim``.
    """

    def __init__(self, N_chunk, D_par, D_kid,
                 scale_transform=scale_transforms.ProximalScaleTransform(1.0)):
        """Pass every argument straight to the parent constructor."""
        super().__init__(N_chunk, D_par, D_kid, scale_transform=scale_transform)

    def base_dist(self, **kwargs):
        """Build the base distribution from ``mu`` / ``sig`` keywords."""
        return dists.CustomDiagonalNormal(**kwargs)

    def initial_params(self):
        """Return zero ``"μ"`` and the inverse diag-transform of all-ones."""
        dim = self.z_dim
        loc0 = np.zeros(dim)
        ones = np.ones(dim)
        return {
            "μ": loc0,
            "σ": self.scale_transform.inverse_diag_transform(ones),
        }

    def get_params(self, params):
        """Map stored parameters to base-distribution keyword arguments."""
        loc = params['μ']
        sig = self.scale_transform.forward_diag_transform(params['σ'])
        return {"mu": loc, "sig": sig}

class DiagonalWithSampleEval_v2(BaselineVarDistWithSampleEval):
    """Diagonal family expressed through ``dists.CustomMultivariateNormal``.

    The per-dimension ``"σ"`` parameter is transformed and then placed on the
    diagonal of a matrix that is passed as ``scale_tril``.
    """

    def __init__(self, N_chunk, D_par, D_kid,
                 scale_transform=scale_transforms.ProximalScaleTransform(1.0)):
        """Pass every argument straight to the parent constructor."""
        super().__init__(N_chunk, D_par, D_kid, scale_transform=scale_transform)

    def base_dist(self, **kwargs):
        """Build the base distribution from ``loc`` / ``scale_tril`` keywords."""
        return dists.CustomMultivariateNormal(**kwargs)

    def initial_params(self):
        """Return zero ``"μ"`` and the inverse diag-transform of all-ones."""
        dim = self.z_dim
        loc0 = np.zeros(dim)
        ones = np.ones(dim)
        return {
            "μ": loc0,
            "σ": self.scale_transform.inverse_diag_transform(ones),
        }

    def get_params(self, params):
        """Map stored parameters to base-distribution keyword arguments."""
        loc = params['μ']
        diag_entries = self.scale_transform.forward_diag_transform(params['σ'])
        return {"loc": loc, "scale_tril": np.diag(diag_entries)}

class BlockGaussianWithSampleEval(BaselineVarDistWithSampleEval):
    """Family backed by ``dists.CustomBlockMultivariateNormal``.

    Two independent parameter sets are stored: ``"μθ"`` / ``"Lθ"`` of size
    ``D_par`` and ``"μz"`` / ``"Lz"`` of size ``N_chunk * D_kid``.  They are
    handed to the base distribution as the ``*_0`` and ``*_1`` keyword
    arguments respectively.
    """

    def __init__(self, N_chunk, D_par, D_kid,
                 scale_transform=scale_transforms.ProximalScaleTransform(1.0)):
        """Pass every argument straight to the parent constructor."""
        super().__init__(N_chunk, D_par, D_kid, scale_transform=scale_transform)

    def base_dist(self, **kwargs):
        """Build the base distribution from the four block keywords."""
        return dists.CustomBlockMultivariateNormal(**kwargs)

    def initial_params(self):
        """Return zero locations and vectorised inverse-transformed identities."""
        dim_par = self.D_par
        dim_kids = self.N_chunk*self.D_kid
        scale_par = self.scale_transform.inverse(np.eye(dim_par))
        scale_kids = self.scale_transform.inverse(np.eye(dim_kids))
        return {
            "μθ": np.zeros(dim_par),
            "Lθ": dists.util.matrix_to_tril_vec(scale_par),
            "μz": np.zeros(dim_kids),
            "Lz": dists.util.matrix_to_tril_vec(scale_kids),
        }

    def get_params(self, params):
        """Map stored parameters to base-distribution keyword arguments."""
        to_matrix = dists.util.vec_to_tril_matrix
        tril_par = self.scale_transform.forward(to_matrix(params['Lθ']))
        tril_kids = self.scale_transform.forward(to_matrix(params['Lz']))
        return {
            "loc_0": params['μθ'],
            "scale_tril_0": tril_par,
            "loc_1": params['μz'],
            "scale_tril_1": tril_kids,
        }

# class Diagonal(BaselineVarDist):
#     def __init__(self, N_chunk, D_par, D_kid,
#                 scale_transform=scale_transforms.ProximalScaleTransform(1.0)):
#         super(Diagonal, self).__init__(N_chunk, D_par, D_kid,
#                                         scale_transform=scale_transform)
#     def base_dist(self, params):
#         return dists.DiagonalNormal(params)

#     def initial_params(self):
#         return {"μ": np.zeros(self.z_dim),
#                 "L": self.scale_transform.inverse_diag_transform(np.ones(self.z_dim))
#                 }

#     def get_params(self, params):
#         return {
#             "mu": params['μ'], 
#             "cov": self.scale_transform.forward_diag_transform(params['L'])**2
#             }


# class LowRankGauss(BaselineVarDist):

#     def __init__(
#                 self, N_chunk, D_par, D_kid,
#                 r,
#                 scale_transform=scale_transforms.ProximalScaleTransform(1.0)):
#         self.r = r
#         super(LowRankGauss, self).__init__(N_chunk, D_par, D_kid, scale_transform)

#     def base_dist(self, params):
#         return dists.LowRankMultivariateNormal(params)
    
#     def initial_params(self):
#         return {
#             "μ" : np.zeros(self.z_dim),
#             "Λ" : np.zeros((self.z_dim, self.r)),
#             "λ" : self.scale_transform.inverse_diag_transform(np.ones(self.z_dim))
#         }

#     def get_params(self, params):
#         return {
#             "loc": params['μ'],
#             "cov_factor": params['Λ'],
#             "cov_diag":  self.scale_transform.forward_diag_transform(params['λ'])
#         }
